#!/usr/bin/env python3

from __future__ import annotations

import argparse
import gzip
import itertools
import json
import multiprocessing
import os
import shlex
import subprocess
import sys
import xml.etree.ElementTree as ET
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import (
    TYPE_CHECKING,
    BinaryIO,
    ClassVar,
    Iterable,
    Iterator,
    NoReturn,
    Protocol,
    Sequence,
    TextIO,
    TypeVar,
    cast,
)

# TODO: import unconditionally Self from typing once we raise Python requirement to 3.11+
if TYPE_CHECKING:
    from typing_extensions import Self

# Command-line value that stands for “any value” of a component
WILDCARD = "*"

# Placeholder in the @NAME@ template style; presumably substituted by the
# build system, like EXTRA_PATH below.
DEFAULT_XKB_ROOT = Path("@XKB_CONFIG_ROOT@")

# Meson needs to fill this in so we can call the tool in the buildir.
EXTRA_PATH = "@MESON_BUILD_ROOT@"
# Prepend the build directory to PATH, skipping empty/undefined entries.
# os.pathsep is used rather than a literal ":" so that PATH stays valid on
# platforms whose separator differs.
os.environ["PATH"] = os.pathsep.join(
    entry for entry in (EXTRA_PATH, os.getenv("PATH")) if entry
)

# Logging configuration inherited by the child processes through their
# environment (used instead of passing --verbose to xkbcli).
os.environ["XKB_LOG_LEVEL"] = "warning"
os.environ["XKB_LOG_VERBOSITY"] = "10"


@dataclass
class RMLVO:
    """
    A rules/model/layout/variant/option combination.

    Rules, model and layout are always set; variant and option may be None.
    """

    DEFAULT_RULES: ClassVar[str] = "evdev"
    DEFAULT_MODEL: ClassVar[str] = "pc105"
    DEFAULT_LAYOUT: ClassVar[str] = "us"
    rules: str
    model: str
    layout: str
    variant: str | None
    option: str | None

    def __iter__(self) -> Iterator[str | None]:
        """Iterate over the five component values, in RMLVO order."""
        yield from (self.rules, self.model, self.layout, self.variant, self.option)

    @property
    def rmlvo(self) -> Iterator[tuple[str, str]]:
        """
        Iterate over (name, value) pairs of the components.

        Rules, model and layout are always produced; variant and option are
        produced only when they are not None (an empty string is kept).
        """
        yield from (
            ("rules", self.rules),
            ("model", self.model),
            ("layout", self.layout),
        )
        optional = (("variant", self.variant), ("option", self.option))
        yield from ((name, value) for name, value in optional if value is not None)

    @classmethod
    def from_rmlvo(
        cls,
        rules: str | None = None,
        model: str | None = None,
        layout: str | None = None,
        variant: str | None = None,
        option: str | None = None,
    ) -> Self:
        """
        Alternate constructor replacing missing or empty rules, model and
        layout with the class defaults; variant and option are kept as given.
        """
        # The RML components must always have a value
        return cls(
            rules=rules if rules else cls.DEFAULT_RULES,
            model=model if model else cls.DEFAULT_MODEL,
            layout=layout if layout else cls.DEFAULT_LAYOUT,
            variant=variant,
            option=option,
        )


@dataclass
class Invocation(RMLVO, metaclass=ABCMeta):
    """
    Base class for the compilation of one RMLVO combination with some tool.

    Subclasses implement `run()`. The instance carries the outcome: exit
    status, error text, compiled keymap and the command that was run.
    """

    exitstatus: int = 77  # default to “skipped”
    error: str | None = None
    keymap: bytes = b""  # The fully compiled keymap
    command: str = ""  # Shell-like rendering of the command that was run

    def __str_iter(self) -> Iterator[str]:
        """Yield the lines of the long (multi-line) YAML report entry."""
        yield f"- rmlvo: {self.to_yaml(self.rmlvo)}"
        yield f"  cmd: {json.dumps(self.command)}"
        yield f"  status: {self.exitstatus}"
        if self.error:
            yield f"  error: {json.dumps(self.error.strip())}"

    def __str__(self) -> str:
        return "\n".join(self.__str_iter())

    @property
    def short(self) -> Iterator[tuple[str, str | int]]:
        """Fields of the concise report: RMLVO, status and error (if any)."""
        yield from self.rmlvo
        yield ("status", self.exitstatus)
        if self.error is not None:
            yield ("error", self.error)

    @staticmethod
    def to_yaml(xs: Iterable[tuple[str, str | int]]) -> str:
        """Format key/value pairs as a YAML flow mapping with JSON-encoded values."""
        fields = ", ".join(f"{k}: {json.dumps(v)}" for k, v in xs)
        return f"{{ {fields} }}"

    def _write(self, fd: BinaryIO) -> None:
        """Write a comment header describing the RMLVO, then the keymap."""
        fd.write(f"// {self.to_yaml(self.rmlvo)}\n".encode("utf-8"))
        fd.write(self.keymap)

    def _write_keymap(self, output_dir: Path, compress: int) -> None:
        """
        Write the keymap to `<output_dir>/<model>/<layout>[(variant)][+option]`.

        A non-zero `compress` gzips the file at that compression level and
        appends ".gz" to its name.
        """
        file_name = self.layout
        if self.variant:
            file_name += f"({self.variant})"
        if self.option:
            file_name += f"+{self.option}"
        if compress:
            # Append the extension: Path.with_suffix() would *replace* whatever
            # follows the last dot of the name.
            file_name += ".gz"

        model_dir = output_dir / self.model
        # exist_ok: several workers may create the same model directory
        model_dir.mkdir(exist_ok=True)
        keymap_file = model_dir / file_name
        if compress:
            # The context manager closes the file; no explicit close() needed
            with gzip.open(keymap_file, "wb", compresslevel=compress) as fd:
                self._write(fd)
        else:
            with keymap_file.open("wb") as fd:
                self._write(fd)

    def _print_result(self, short: bool, verbose: bool) -> None:
        """
        Print the report entry: failures always go to stderr, successes go
        to stdout only when `verbose` is set.
        """
        if self.exitstatus != 0:
            target = sys.stderr
        else:
            target = sys.stdout if verbose else None

        if target:
            if short:
                print("-", self.to_yaml(self.short), file=target)
            else:
                print(self, file=target)

    @classmethod
    @abstractmethod
    def run(
        cls,
        i: Self,
        xkb_root: Path,
        output_dir: Path | None,
        compress: int,
        *args,
        **kwargs,
    ) -> Self:
        """Process the invocation `i` and return it with its results filled in."""
        ...

    @classmethod
    def run_all(
        cls,
        xkb_root: Path,
        combos: Iterable[Self],
        combos_count: int,
        njobs: int,
        keymap_output_dir: Path | None,
        verbose: bool,
        short: bool,
        progress_bar: ProgressBar[Iterable[Self]],
        chunksize: int,
        compress: int,
        **kwargs,
    ) -> bool:
        """
        Run all the given combinations in a pool of `njobs` processes and
        print their results.

        Returns True if anything failed: either an invocation ended with a
        non-zero status, or the keymap output directory already exists.
        """
        if keymap_output_dir:
            try:
                keymap_output_dir.mkdir(parents=True)
            except FileExistsError as e:
                print(e, file=sys.stderr)
                # This is a failure: the caller uses the result as exit status
                return True

        failed = False
        with multiprocessing.Pool(njobs) as p:
            f = partial(
                cls.run,
                xkb_root=xkb_root,
                output_dir=keymap_output_dir,
                compress=compress,
            )
            results = p.imap_unordered(f, combos, chunksize=chunksize)
            for invocation in progress_bar(
                results, total=combos_count, file=sys.stdout
            ):
                if invocation.exitstatus != 0:
                    failed = True

                invocation._print_result(short, verbose)

        return failed


@dataclass
class XkbCompInvocation(Invocation):
    """Compile keymaps with the output of ``setxkbmap -print`` piped to ``xkbcomp``."""

    @classmethod
    def run(
        cls,
        i: Self,
        xkb_root: Path,
        output_dir: Path | None,
        compress: int,
        *args,
        **kwargs,
    ) -> Self:
        """Compile `i` and, if `output_dir` is set, write the resulting keymap."""
        i._run(xkb_root, not output_dir)
        if output_dir:
            i._write_keymap(output_dir, compress)
        return i

    def _run(self, xkb_root: Path, test: bool) -> None:
        """
        Run the setxkbmap/xkbcomp pipeline and record command, exit status,
        error and keymap. `test` is not used by this implementation.
        """
        setxkbmap_args = (
            "setxkbmap",
            "-print",
            # Informative only; we set CWD to ensure proper rules are loaded
            "-I",
            str(xkb_root),
            *itertools.chain.from_iterable((f"-{k}", v) for k, v in self.rmlvo),
        )
        xkbcomp_args = ("xkbcomp", "-I", f"-I{xkb_root}", "-xkb", "-", "-")

        # Quote each command separately: passing "|" through shlex.join()
        # would quote it, and the reported command would not be a pipeline.
        self.command = f"{shlex.join(setxkbmap_args)} | {shlex.join(xkbcomp_args)}"

        setxkbmap = subprocess.run(
            setxkbmap_args,
            capture_output=True,
            text=True,
            cwd=xkb_root,
        )
        if "Cannot open display" in setxkbmap.stderr:
            self.error = setxkbmap.stderr
            self.exitstatus = 90
            return

        xkbcomp = subprocess.run(
            xkbcomp_args,
            input=setxkbmap.stdout,
            capture_output=True,
            text=True,
        )
        if xkbcomp.returncode != 0:
            self.error = (
                "failed to compile keymap:\n"
                f"------\nstderr:\n{xkbcomp.stderr}\n"
                f"------\nstdout:\n{xkbcomp.stdout}\n"
            )
            self.exitstatus = xkbcomp.returncode
        else:
            self.keymap = xkbcomp.stdout.encode("utf-8")
            self.exitstatus = 0


@dataclass
class XkbCompToXkbcommonInvocation(XkbCompInvocation):
    """
    Use setxkbmap & xkbcomp, then pipe the result to xkbcli in order to enable
    comparison with direct xkbcli compilation with the same format.
    """

    def _run(self, xkb_root: Path, test: bool) -> None:
        """
        Run the setxkbmap/xkbcomp pipeline, then feed the keymap to
        xkbcli-compile-keymap and keep the output of the latter.

        The second step is skipped if the first one failed or if `test` is set.
        """
        super()._run(xkb_root, test)

        if self.exitstatus or test:
            return

        args = (
            "xkbcli-compile-keymap",  # this is run in the builddir
            # Not used: keymap is already compiled
            # "--include", xkb_root,
            # Not needed, because we set XKB_LOG_LEVEL and XKB_LOG_VERBOSITY in env
            # "--verbose",
            "--keymap",
        )

        try:
            completed = subprocess.run(
                args,
                check=True,
                capture_output=True,
                input=self.keymap,
            )
        except subprocess.CalledProcessError as err:
            # The streams are captured as bytes: decode them so that the report
            # contains the actual text rather than a b'...' representation.
            stderr = err.stderr.decode("utf-8", errors="replace")
            stdout = err.stdout.decode("utf-8", errors="replace")
            stdin = self.keymap.decode("utf-8", errors="replace")
            self.error = (
                "failed to compile keymap:\n"
                f"------\nstderr:\n{stderr}\n"
                f"------\nstdout:\n{stdout}\n"
                f"------\nstdin:\n{stdin}\n"
            )
            self.exitstatus = err.returncode
        else:
            self.keymap = completed.stdout


@dataclass
class XkbcommonInvocation(Invocation):
    """Compile keymaps directly with ``xkbcli-compile-keymap``."""

    UNRECOGNIZED_KEYSYM_ERROR: ClassVar[str] = "XKB-107"

    def _check_stderr(self, stderr: str) -> bool:
        """
        Inspect the stderr of a run that exited successfully.

        Returns True and sets the status to 0 unless the unrecognized-keysym
        code is present; in that case the first line mentioning it becomes the
        error, the status is set to 99 and False is returned.
        """
        code = self.UNRECOGNIZED_KEYSYM_ERROR
        if code not in stderr:
            self.exitstatus = 0
            return True

        first_match = next((ln for ln in stderr.splitlines() if code in ln), None)
        if first_match is not None:
            self.error = first_match
        self.exitstatus = 99  # tool doesn't generate this one
        return False

    @classmethod
    def run(
        cls,
        i: Self,
        xkb_root: Path,
        output_dir: Path | None,
        compress: int,
        *args,
        **kwargs,
    ) -> Self:
        """Compile `i`; the tool runs in test mode when there is no output directory."""
        test_only = not output_dir
        i._run(xkb_root, output_dir, test_only, compress)
        return i

    def _run(
        self, xkb_root: Path, output_dir: Path | None, test: bool, compress: int
    ) -> None:
        """
        Run xkbcli-compile-keymap for this combination, record the outcome
        and, if `output_dir` is set, write the keymap file.
        """
        # "--verbose" is not needed, because we set XKB_LOG_LEVEL and
        # XKB_LOG_VERBOSITY in env
        cmd = [
            "xkbcli-compile-keymap",  # this is run in the builddir
            "--include",
            str(xkb_root),
        ]
        for name, value in self.rmlvo:
            cmd.extend((f"--{name}", value))
        if test:
            cmd.append("--test")
        self.command = shlex.join(cmd)

        try:
            completed = subprocess.run(cmd, text=True, check=True, capture_output=True)
        except subprocess.CalledProcessError as err:
            self.exitstatus = err.returncode
            self.error = (
                "failed to compile keymap:\n"
                f"------\nstderr:\n{err.stderr}\n"
                f"------\nstdout:\n{err.stdout}\n"
            )
        else:
            stderr_is_clean = self._check_stderr(completed.stderr)
            if stderr_is_clean:
                self.keymap = completed.stdout.encode("utf-8")

        if output_dir:
            self._write_keymap(output_dir, compress)


@dataclass
class XkbcommonToXkbcompInvocation(XkbcommonInvocation):
    """
    Use xkbcli, then pipe the result to xkbcomp in order to check
    that xkbcomp can parse xkbcommon output.
    """

    def _run(
        self, xkb_root: Path, output_dir: Path | None, test: bool, compress: int
    ) -> None:
        """
        Compile with xkbcli, then feed the keymap to xkbcomp and keep the
        output of the latter.

        `test` is ignored: the xkbcli output is always needed as xkbcomp input.
        """
        # No output directory for the first stage: the keymap file is written
        # only once, below, with the final content.
        super()._run(xkb_root, None, False, compress)

        if not self.exitstatus:
            self._recompile_with_xkbcomp()

        if output_dir:
            self._write_keymap(output_dir, compress)

    def _recompile_with_xkbcomp(self) -> None:
        """Replace the keymap with the result of compiling it with xkbcomp."""
        args = ("xkbcomp", "-xkb", "-opt", "g", "-", "-")

        try:
            completed = subprocess.run(
                args,
                check=True,
                capture_output=True,
                input=self.keymap,
            )
        except subprocess.CalledProcessError as err:
            # The streams are captured as bytes: decode them so that the report
            # contains the actual text rather than a b'...' representation.
            stderr = err.stderr.decode("utf-8", errors="replace")
            stdout = err.stdout.decode("utf-8", errors="replace")
            stdin = self.keymap.decode("utf-8", errors="replace")
            self.error = (
                "failed to compile keymap:\n"
                f"------\nstderr:\n{stderr}\n"
                f"------\nstdout:\n{stdout}\n"
                f"------\nstdin:\n{stdin}\n"
            )
            self.exitstatus = err.returncode
        else:
            self.keymap = completed.stdout


@dataclass
class Layout:
    """A layout name with the variants to test (None: no variant)."""

    name: str
    variants: list[str | None]

    @classmethod
    def parse(cls, e: ET.Element, variant: list[str] | None = None) -> Self:
        """
        Create a layout from a registry `layout` XML element.

        If `variant` is a non-empty list, it is used as the list of variants.
        Otherwise the variants are None (no variant) followed by every
        variant listed in the element.

        Raises ValueError if the layout name or a variant name is missing
        or empty.
        """
        name_elem = e.find("configItem/name")
        if name_elem is None:
            raise ValueError("Layout name not found")
        if not variant:
            variants: list[str | None] = [None]
            variants.extend(
                cls.parse_text(v)
                for v in e.findall("variantList/variant/configItem/name")
            )
        else:
            # The type is given as a string so that it is not evaluated at runtime
            variants = cast("list[str | None]", variant)
        return cls(cls.parse_text(name_elem), variants)

    @staticmethod
    def parse_text(e: ET.Element | None) -> str:
        """Return the text of `e`; raise ValueError if `e` is None or has no text."""
        if e is None or not e.text:
            raise ValueError("Name not found")
        return e.text


class Registry:
    """
    Helpers to locate the XML registry files and to expand them into the
    RMLVO combinations to test.
    """

    @classmethod
    def parse_path(cls, xkb_root: Path, path: Path) -> Path:
        """
        Resolve a registry file given on the command line.

        An existing file is returned unchanged. A bare name (single path
        component) is looked up in `<xkb_root>/rules`, with ".xml" appended
        unless the name already ends with it.

        Raises ValueError if no file is found.
        """
        if path.is_file():
            return path
        if len(path.parts) == 1:
            file_name = (
                path
                if path.suffix == ".xml"
                # NOTE: If we got evdev.extras, we want to keep the current suffix
                else path.with_suffix(f"{path.suffix}.xml")
            )
            candidate = xkb_root / "rules" / file_name
            if candidate.is_file():
                return candidate
        raise ValueError(f"Cannot resolve registry file: {path}")

    @staticmethod
    def _has_name(e: ET.Element, name: str) -> bool:
        """Check whether one of the `configItem/name` children of `e` has the text `name`."""
        # The comparison is done in Python rather than with an ElementPath
        # predicate, which cannot be built safely if the name contains a quote.
        return any(
            "".join(n.itertext()) == name for n in e.findall("configItem/name")
        )

    @classmethod
    def parse(
        cls,
        paths: Sequence[Path],
        tool: type[Invocation],
        rules: str | None,
        model: str | None,
        layout: str | None,
        variant: str | None,
        option: str | None,
    ) -> tuple[int, Iterator[Invocation]]:
        """
        Compute the combinations to test from the given registry files.

        For model and option, None means “every value of the registries”;
        for layout, any falsy value does. A given `variant` is a
        colon-separated list and requires `layout` to be set.

        Returns the number of combinations and a lazy iterator of the
        corresponding `tool` invocations.

        Raises ValueError if `variant` is set without `layout`.
        """
        if variant and not layout:
            raise ValueError("Variant must be set together with layout")

        models: tuple[str, ...] = ()
        layouts: tuple[Layout, ...] = ()
        options: tuple[str, ...] = ()

        for path in paths:
            root = ET.fromstring(path.read_text(encoding="utf-8"))

            # Models
            if model is None:
                models += tuple(
                    e.text
                    for e in root.findall("modelList/model/configItem/name")
                    if e.text
                )
            elif not models:
                # A fixed value is added only once, whatever the number of files
                models += (model,)

            # Layouts/variants
            if not layout:
                layouts += tuple(map(Layout.parse, root.findall("layoutList/layout")))
            elif variant is None:
                layouts += tuple(
                    Layout.parse(e)
                    for e in root.findall("layoutList/layout")
                    if cls._has_name(e, layout)
                )
            elif not layouts:
                # The type is given as a string so that it is not evaluated at runtime
                layouts += (
                    Layout(layout, cast("list[str | None]", variant.split(":"))),
                )

            # Options
            if option is None:
                options += tuple(
                    e.text
                    for e in root.findall("optionList/group/option/configItem/name")
                    if e.text
                )
            elif not options and option:
                options += (option,)

        # Some registry may be only partial, e.g.: *.extras.xml
        if not models:
            models = (RMLVO.DEFAULT_MODEL,)
        if not layouts:
            layouts = (Layout(RMLVO.DEFAULT_LAYOUT, [None]),)

        # Each model/layout/variant is tested without option, then with each option
        variants_count = sum(len(lay.variants) for lay in layouts)
        count = len(models) * variants_count * (1 + len(options))

        # The list of combos can be huge, so better to use a generator instead
        def iter_combos() -> Iterator[Invocation]:
            for m in models:
                for lay in layouts:
                    for v in lay.variants:
                        for opt in (None, *options):
                            yield tool.from_rmlvo(
                                rules=rules,
                                model=m,
                                layout=lay.name,
                                variant=v,
                                option=opt,
                            )

        return count, iter_combos()


T = TypeVar("T")


class ProgressBar(Protocol[T]):
    """
    Type of a progress bar function: it is called with the object to track,
    the expected total and the output file, and returns an object of the
    same type as the one it was given.
    """

    # A protocol is required here, because Callable does not handle keyword
    # arguments.
    def __call__(
        self,
        x: T,
        total: int,
        file: TextIO | None,
    ) -> T:
        ...


# The function generating the progress bar (if any).
def create_progress_bar(verbose: bool) -> ProgressBar[T]:
    """
    Return the function used to wrap the iteration over the results.

    It is `tqdm` if all of these hold: `verbose` is unset, stdout is a
    terminal and the `tqdm` module can be imported. Otherwise it is a
    function returning its first argument unchanged.
    """

    def noop_progress_bar(x: T, total: int, file: TextIO | None = None) -> T:
        return x

    # Ask the stream itself: unlike os.isatty(sys.stdout.fileno()), this also
    # works if stdout was replaced by an object without a file descriptor.
    if verbose or not sys.stdout.isatty():
        return noop_progress_bar

    try:
        from tqdm import tqdm
    except ImportError:
        # tqdm is optional
        return noop_progress_bar

    return cast("ProgressBar[T]", tqdm)


def main() -> NoReturn:
    """
    Command-line entry point: parse the arguments, build the combinations
    to test, run them with the selected tool and exit with the failure
    flag as status.
    """
    default_tool = "xkbcommon"
    tools: dict[str, type[Invocation]] = {
        default_tool: XkbcommonInvocation,
        "xkbcommon-xkbcomp": XkbcommonToXkbcompInvocation,
        "xkbcomp": XkbCompInvocation,
        "xkbcomp-xkbcommon": XkbCompToXkbcommonInvocation,
    }

    parser = argparse.ArgumentParser(
        description="""
                    This tool compiles a keymap for each layout, variant and
                    options combination in the given rules XML file. The output
                    of this tool is YAML, use your favorite YAML parser to
                    extract error messages. Errors are printed to stderr.
                    """
    )
    add = parser.add_argument
    add(
        "paths",
        metavar="/path/to/rules.xml",
        nargs="*",
        type=Path,
        default=(),
        help="Path to xkeyboard-config's XML registry file",
    )
    add("--xkb-root", default=DEFAULT_XKB_ROOT, type=Path, help="XKB root directory")
    add(
        "--tool",
        choices=tools.keys(),
        type=str,
        default=default_tool,
        help="parsing tool to use",
    )
    add(
        "--jobs",
        "-j",
        type=int,
        default=4 * (os.cpu_count() or 1),
        help="number of processes to use",
    )
    add("--chunksize", default=1, type=int)
    add("--verbose", "-v", default=False, action="store_true")
    add("--short", default=False, action="store_true", help="Concise output")
    add(
        "--keymap-output-dir",
        default=None,
        type=Path,
        help="Directory to print compiled keymaps to",
    )
    add("--compress", type=int, default=0, help="Compression level of keymaps files")
    # The RMLVO filters only differ by their name, default value and help
    for flag, default, help_text in (
        ("--rules", RMLVO.DEFAULT_RULES, "Rule set to use"),
        ("--model", "", "Only test the given model"),
        ("--layout", WILDCARD, "Only test the given layout"),
        (
            "--variant",
            WILDCARD,
            "Only test the given variants (colon-separated list)",
        ),
        ("--option", WILDCARD, "Only test the given option"),
    ):
        add(flag, default=default, type=str, help=help_text)
    add("--no-iterations", "-1", action="store_true", help="Only test one combo")

    args = parser.parse_args()

    xkb_root: Path = args.xkb_root
    verbose: bool = args.verbose
    progress_bar: ProgressBar[Iterable[Invocation]] = create_progress_bar(verbose)
    tool = tools[args.tool]

    def unless_wildcard(value: str) -> str | None:
        """Map the wildcard to None and keep any other value."""
        return None if value == WILDCARD else value

    # NOTE: We test only one set of rules; handle wild card only for consistency
    #       with other components.
    rules = unless_wildcard(args.rules)
    model = unless_wildcard(args.model)
    layout = unless_wildcard(args.layout)
    variant = unless_wildcard(args.variant)
    option = unless_wildcard(args.option)

    if args.paths:
        paths = tuple(Registry.parse_path(xkb_root, p) for p in args.paths)
    else:
        # Without any given registry, use the one of the given rules,
        # else the one of the default rules.
        fallback = xkb_root / "rules" / (rules or RMLVO.DEFAULT_RULES)
        paths = (fallback.with_suffix(f"{fallback.suffix}.xml"),)

    # This needs to be valid YAML. Currently unused, so left as comments
    print("# xkb root:", json.dumps(str(xkb_root)), file=sys.stderr)
    print("# paths:", json.dumps(tuple(map(str, paths))), file=sys.stderr)

    if args.no_iterations:
        # A single combo, built from the command-line values only
        single = tool.from_rmlvo(
            rules=rules, model=model, layout=layout, variant=variant, option=option
        )
        count = 1
        iter_combos = iter((single,))
    else:
        count, iter_combos = Registry.parse(
            paths, tool, rules, model, layout, variant, option
        )

    failed = tool.run_all(
        xkb_root=xkb_root,
        combos=iter_combos,
        combos_count=count,
        njobs=args.jobs,
        keymap_output_dir=args.keymap_output_dir,
        verbose=verbose,
        short=args.short,
        progress_bar=progress_bar,
        chunksize=args.chunksize,
        compress=args.compress,
    )
    sys.exit(failed)


if __name__ == "__main__":
    # Script entry point. main() ends with sys.exit(), whose SystemExit is not
    # caught here; only an interruption is handled.
    try:
        main()
    except KeyboardInterrupt:
        # Ctrl+C: report it on stdout, as a "#"-prefixed line
        print("# Exiting after Ctrl+C")
